{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# please set cwd to the root of mmrazor repo."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 使用MMRazor对ResNet34进行剪枝"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "本教程主要介绍如何手动配置剪枝config。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 回顾MMCls"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 跨库调用resnet34配置文件\n",
    "\n",
    "首先，我们跨库调用resnet34的配置文件。通过跨库调用，我们可以继承原有配置文件的所有内容。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare the scratch directory that holds the generated config files.\n",
    "work_dir = './demo/tmp/'\n",
    "# makedirs also creates missing parent directories and does nothing if the\n",
    "# directory already exists, so the cell is safe to re-run.\n",
    "os.makedirs(work_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mmengine import Config\n",
    "\n",
    "\n",
    "def write_config(config_str, filename):\n",
    "    \"\"\"Write ``config_str`` to ``filename``, replacing any existing file.\n",
    "\n",
    "    Args:\n",
    "        config_str (str): Full text of the config file.\n",
    "        filename (str): Destination path; its directory must already exist.\n",
    "    \"\"\"\n",
    "    # Explicit UTF-8 keeps the written file independent of the platform's\n",
    "    # default text encoding.\n",
    "    with open(filename, 'w', encoding='utf-8') as f:\n",
    "        f.write(config_str)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'type': 'ImageClassifier', 'backbone': {'type': 'ResNet', 'depth': 34, 'num_stages': 4, 'out_indices': (3,), 'style': 'pytorch'}, 'neck': {'type': 'GlobalAveragePooling'}, 'head': {'type': 'LinearClsHead', 'num_classes': 1000, 'in_channels': 512, 'loss': {'type': 'CrossEntropyLoss', 'loss_weight': 1.0}, 'topk': (1, 5)}, '_scope_': 'mmcls'}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/liukai/miniconda3/envs/lab2max/lib/python3.9/site-packages/mmengine/config/utils.py:51: UserWarning: There is not `Config` define in {'Name': 'convnext-base_3rdparty_in21k', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 15359124480, 'Parameters': 88591464}, 'In Collection': 'ConvNeXt', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_3rdparty_in21k_20220124-13b83eec.pth', 'Converted From': {'Weights': 'https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth', 'Code': 'https://github.com/facebookresearch/ConvNeXt'}}\n",
      "  warnings.warn(f'There is not `Config` define in {model_cfg}')\n",
      "/home/liukai/miniconda3/envs/lab2max/lib/python3.9/site-packages/mmengine/config/utils.py:51: UserWarning: There is not `Config` define in {'Name': 'convnext-large_3rdparty_in21k', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 34368026112, 'Parameters': 197767336}, 'In Collection': 'ConvNeXt', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_3rdparty_in21k_20220124-41b5a79f.pth', 'Converted From': {'Weights': 'https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth', 'Code': 'https://github.com/facebookresearch/ConvNeXt'}}\n",
      "  warnings.warn(f'There is not `Config` define in {model_cfg}')\n",
      "/home/liukai/miniconda3/envs/lab2max/lib/python3.9/site-packages/mmengine/config/utils.py:51: UserWarning: There is not `Config` define in {'Name': 'convnext-xlarge_3rdparty_in21k', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 60929820672, 'Parameters': 350196968}, 'In Collection': 'ConvNeXt', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/convnext/convnext-xlarge_3rdparty_in21k_20220124-f909bad7.pth', 'Converted From': {'Weights': 'https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth', 'Code': 'https://github.com/facebookresearch/ConvNeXt'}}\n",
      "  warnings.warn(f'There is not `Config` define in {model_cfg}')\n",
      "/home/liukai/miniconda3/envs/lab2max/lib/python3.9/site-packages/mmengine/config/utils.py:51: UserWarning: There is not `Config` define in {'Name': 'swinv2-base-w12_3rdparty_in21k-192px', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 8510000000, 'Parameters': 87920000}, 'In Collection': 'Swin-Transformer V2', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/swin-v2/pretrain/swinv2-base-w12_3rdparty_in21k-192px_20220803-f7dc9763.pth', 'Converted From': {'Weights': 'https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth', 'Code': 'https://github.com/microsoft/Swin-Transformer'}}\n",
      "  warnings.warn(f'There is not `Config` define in {model_cfg}')\n",
      "/home/liukai/miniconda3/envs/lab2max/lib/python3.9/site-packages/mmengine/config/utils.py:51: UserWarning: There is not `Config` define in {'Name': 'swinv2-large-w12_3rdparty_in21k-192px', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 19040000000, 'Parameters': 196740000}, 'In Collection': 'Swin-Transformer V2', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/swin-v2/pretrain/swinv2-large-w12_3rdparty_in21k-192px_20220803-d9073fee.pth', 'Converted From': {'Weights': 'https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth', 'Code': 'https://github.com/microsoft/Swin-Transformer'}}\n",
      "  warnings.warn(f'There is not `Config` define in {model_cfg}')\n"
     ]
    }
   ],
   "source": [
    "# Prepare the pretrain config: it only inherits mmcls' ResNet-34 config through\n",
    "# the cross-library `mmcls::` prefix.\n",
    "# work_dir already ends with '/', so concatenate directly (as the other paths do)\n",
    "# instead of inserting a second slash.\n",
    "pretrain_config_path = work_dir + 'pretrain.py'\n",
    "config_string = \"\"\"\n",
    "_base_ = ['mmcls::resnet/resnet34_8xb32_in1k.py']\n",
    "\"\"\"\n",
    "write_config(config_string, pretrain_config_path)\n",
    "# Print the `model` field of the loaded config.\n",
    "print(Config.fromfile(pretrain_config_path)['model'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the pretrain config written above (the prune config does not exist yet at\n",
    "# this point). `timeout 2` stops the run after two seconds, so this is presumably\n",
    "# only meant as a quick launch check.\n",
    "! timeout 2 python ./tools/train.py $pretrain_config_path"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 准备剪枝config"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. 增加pretrained参数\n",
    "2. 将resnet34模型装入剪枝算法wrapper中\n",
    "3. 配置剪枝比例\n",
    "4. 运行"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. 增加预训练参数\n",
    "我们将原有的'model'字段取出，命名为architecture，并且给architecture增加init_cfg字段用来加载预训练模型参数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint URL of the pretrained ResNet-34; it is also passed to\n",
    "# get_l1_prune_config.py near the end of the notebook.\n",
    "checkpoint_path = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth'\n",
    "prune_config_path = work_dir + 'prune.py'\n",
    "# Append to the config text: take the inherited model as `architecture` and give\n",
    "# it an init_cfg of type 'Pretrained' that points at the checkpoint.\n",
    "# NOTE(review): `+=` makes this cell non-idempotent; re-running it appends the\n",
    "# section a second time.\n",
    "config_string += \"\"\"\\n\n",
    "data_preprocessor = {'type': 'mmcls.ClsDataPreprocessor'}\n",
    "architecture = _base_.model\n",
    "architecture.update({\n",
    "    'init_cfg': {\n",
    "        'type':\n",
    "        'Pretrained',\n",
    "        'checkpoint':\n",
    "        'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth'  # noqa\n",
    "    }\n",
    "})\n",
    "\"\"\"\n",
    "write_config(config_string, prune_config_path)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. 将resnet34模型装入剪枝算法wrapper中\n",
    "\n",
    "我们将原有的model作为architecture放入到ItePruneAlgorithm算法中，并且将ItePruneAlgorithm作为新的model字段。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Append the pruning wrapper: ItePruneAlgorithm becomes the new `model` and takes\n",
    "# the original model as its `architecture`.\n",
    "# The `target_pruning_ratio={}` line is a placeholder that a later cell swaps out\n",
    "# with str.replace, so its exact spelling matters.\n",
    "# NOTE(review): `+=` makes this cell non-idempotent; re-running it appends the\n",
    "# section a second time.\n",
    "config_string += \"\"\"\n",
    "target_pruning_ratio={}\n",
    "model = dict(\n",
    "    _delete_=True,\n",
    "    _scope_='mmrazor',\n",
    "    type='ItePruneAlgorithm',\n",
    "    architecture=architecture,\n",
    "    mutator_cfg=dict(\n",
    "        type='ChannelMutator',\n",
    "        channel_unit_cfg=dict(\n",
    "            type='L1MutableChannelUnit',\n",
    "            default_args=dict(choice_mode='ratio'))),\n",
    "    target_pruning_ratio=target_pruning_ratio,\n",
    "    step_freq=1,\n",
    "    prune_times=1,\n",
    ")\n",
    "\"\"\"\n",
    "write_config(config_string, prune_config_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "配置到这一步时，我们的config文件已经能够运行了。但是因为我们没有配置target_pruning_ratio，因此现在跑起来就和直接用原有config跑起来没有区别。接下来我们会介绍如何配置剪枝比例。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Launch training with the prune config written so far; `timeout 2` stops the\n",
    "# run after two seconds.\n",
    "! timeout 2 python ./tools/train.py $prune_config_path"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. 配置剪枝比例"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们使用tracer解析模型，进而获得剪枝节点。为了方便用户配置剪枝节点比例，我们提供了一个获得剪枝节点剪枝比例配置的工具。通过该工具，我们可以方便地对剪枝比例进行配置。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "    \"backbone.conv1_(0, 64)_64\":1.0,\n",
      "    \"backbone.layer1.0.conv1_(0, 64)_64\":1.0,\n",
      "    \"backbone.layer1.1.conv1_(0, 64)_64\":1.0,\n",
      "    \"backbone.layer1.2.conv1_(0, 64)_64\":1.0,\n",
      "    \"backbone.layer2.0.conv1_(0, 128)_128\":1.0,\n",
      "    \"backbone.layer2.0.conv2_(0, 128)_128\":1.0,\n",
      "    \"backbone.layer2.1.conv1_(0, 128)_128\":1.0,\n",
      "    \"backbone.layer2.2.conv1_(0, 128)_128\":1.0,\n",
      "    \"backbone.layer2.3.conv1_(0, 128)_128\":1.0,\n",
      "    \"backbone.layer3.0.conv1_(0, 256)_256\":1.0,\n",
      "    \"backbone.layer3.0.conv2_(0, 256)_256\":1.0,\n",
      "    \"backbone.layer3.1.conv1_(0, 256)_256\":1.0,\n",
      "    \"backbone.layer3.2.conv1_(0, 256)_256\":1.0,\n",
      "    \"backbone.layer3.3.conv1_(0, 256)_256\":1.0,\n",
      "    \"backbone.layer3.4.conv1_(0, 256)_256\":1.0,\n",
      "    \"backbone.layer3.5.conv1_(0, 256)_256\":1.0,\n",
      "    \"backbone.layer4.0.conv1_(0, 512)_512\":1.0,\n",
      "    \"backbone.layer4.0.conv2_(0, 512)_512\":1.0,\n",
      "    \"backbone.layer4.1.conv1_(0, 512)_512\":1.0,\n",
      "    \"backbone.layer4.2.conv1_(0, 512)_512\":1.0\n",
      "}"
     ]
    }
   ],
   "source": [
    "# Write the channel-unit ratio template for the pretrain config to a JSON file,\n",
    "# print it, then delete the file again.\n",
    "ratio_template_path = work_dir + 'prune_ratio_template.json'\n",
    "# `> /dev/null 2>&1` silences stdout and stderr in any POSIX shell. The bash-only\n",
    "# `&>` form is read by a plain sh as 'run in background', which would let the\n",
    "# `cat` below race with the tool.\n",
    "! python ./tools/pruning/get_channel_units.py $pretrain_config_path --choice -o $ratio_template_path > /dev/null 2>&1\n",
    "! cat $ratio_template_path\n",
    "! rm $ratio_template_path"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们按如下方式修改该配置模板，并将其替换到我们的剪枝配置文件中。\n",
    "\n",
    "（该配置来源于：Li, Hao, et al. \"Pruning filters for efficient convnets.\" arXiv preprint arXiv:1608.08710 (2016).）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Config snippet holding the per-unit pruning ratios; a later cell substitutes it\n",
    "# for the `target_pruning_ratio={}` placeholder in config_string.\n",
    "target_config = \"\"\"\n",
    "un_prune = 1.0\n",
    "stage_ratio_1 = 0.5\n",
    "stage_ratio_2 = 0.4\n",
    "stage_ratio_3 = 0.6\n",
    "stage_ratio_4 = un_prune\n",
    "\n",
    "target_pruning_ratio = {\n",
    "    # stage 1\n",
    "    'backbone.conv1_(0, 64)_64': un_prune,  # short cut layers\n",
    "    'backbone.layer1.0.conv1_(0, 64)_64': stage_ratio_1,\n",
    "    'backbone.layer1.1.conv1_(0, 64)_64': stage_ratio_1,\n",
    "    'backbone.layer1.2.conv1_(0, 64)_64': un_prune,\n",
    "    # stage 2\n",
    "    'backbone.layer2.0.conv1_(0, 128)_128': un_prune,\n",
    "    'backbone.layer2.0.conv2_(0, 128)_128': un_prune,  # short cut layers\n",
    "    'backbone.layer2.1.conv1_(0, 128)_128': stage_ratio_2,\n",
    "    'backbone.layer2.2.conv1_(0, 128)_128': stage_ratio_2,\n",
    "    'backbone.layer2.3.conv1_(0, 128)_128': un_prune,\n",
    "    # stage 3\n",
    "    'backbone.layer3.0.conv1_(0, 256)_256': un_prune,\n",
    "    'backbone.layer3.0.conv2_(0, 256)_256': un_prune,  # short cut layers\n",
    "    'backbone.layer3.1.conv1_(0, 256)_256': stage_ratio_3,\n",
    "    'backbone.layer3.2.conv1_(0, 256)_256': stage_ratio_3,\n",
    "    'backbone.layer3.3.conv1_(0, 256)_256': stage_ratio_3,\n",
    "    'backbone.layer3.4.conv1_(0, 256)_256': stage_ratio_3,\n",
    "    'backbone.layer3.5.conv1_(0, 256)_256': un_prune,\n",
    "    # stage 4\n",
    "    'backbone.layer4.0.conv1_(0, 512)_512': stage_ratio_4,\n",
    "    'backbone.layer4.0.conv2_(0, 512)_512': un_prune,  # short cut layers\n",
    "    'backbone.layer4.1.conv1_(0, 512)_512': stage_ratio_4,\n",
    "    'backbone.layer4.2.conv1_(0, 512)_512': stage_ratio_4\n",
    "}\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "_base_ = ['mmcls::resnet/resnet34_8xb32_in1k.py']\n",
      "\n",
      "\n",
      "data_preprocessor = {'type': 'mmcls.ClsDataPreprocessor'}\n",
      "architecture = _base_.model\n",
      "architecture.update({\n",
      "    'init_cfg': {\n",
      "        'type':\n",
      "        'Pretrained',\n",
      "        'checkpoint':\n",
      "        'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth'  # noqa\n",
      "    }\n",
      "})\n",
      "\n",
      "\n",
      "un_prune = 1.0\n",
      "stage_ratio_1 = 0.5\n",
      "stage_ratio_2 = 0.4\n",
      "stage_ratio_3 = 0.6\n",
      "stage_ratio_4 = un_prune\n",
      "\n",
      "target_pruning_ratio = {\n",
      "    # stage 1\n",
      "    'backbone.conv1_(0, 64)_64': un_prune,  # short cut layers\n",
      "    'backbone.layer1.0.conv1_(0, 64)_64': stage_ratio_1,\n",
      "    'backbone.layer1.1.conv1_(0, 64)_64': stage_ratio_1,\n",
      "    'backbone.layer1.2.conv1_(0, 64)_64': un_prune,\n",
      "    # stage 2\n",
      "    'backbone.layer2.0.conv1_(0, 128)_128': un_prune,\n",
      "    'backbone.layer2.0.conv2_(0, 128)_128': un_prune,  # short cut layers\n",
      "    'backbone.layer2.1.conv1_(0, 128)_128': stage_ratio_2,\n",
      "    'backbone.layer2.2.conv1_(0, 128)_128': stage_ratio_2,\n",
      "    'backbone.layer2.3.conv1_(0, 128)_128': un_prune,\n",
      "    # stage 3\n",
      "    'backbone.layer3.0.conv1_(0, 256)_256': un_prune,\n",
      "    'backbone.layer3.0.conv2_(0, 256)_256': un_prune,  # short cut layers\n",
      "    'backbone.layer3.1.conv1_(0, 256)_256': stage_ratio_3,\n",
      "    'backbone.layer3.2.conv1_(0, 256)_256': stage_ratio_3,\n",
      "    'backbone.layer3.3.conv1_(0, 256)_256': stage_ratio_3,\n",
      "    'backbone.layer3.4.conv1_(0, 256)_256': stage_ratio_3,\n",
      "    'backbone.layer3.5.conv1_(0, 256)_256': un_prune,\n",
      "    # stage 4\n",
      "    'backbone.layer4.0.conv1_(0, 512)_512': stage_ratio_4,\n",
      "    'backbone.layer4.0.conv2_(0, 512)_512': un_prune,  # short cut layers\n",
      "    'backbone.layer4.1.conv1_(0, 512)_512': stage_ratio_4,\n",
      "    'backbone.layer4.2.conv1_(0, 512)_512': stage_ratio_4\n",
      "}\n",
      "\n",
      "model = dict(\n",
      "    _delete_=True,\n",
      "    _scope_='mmrazor',\n",
      "    type='ItePruneAlgorithm',\n",
      "    architecture=architecture,\n",
      "    mutator_cfg=dict(\n",
      "        type='ChannelMutator',\n",
      "        channel_unit_cfg=dict(\n",
      "            type='L1MutableChannelUnit',\n",
      "            default_args=dict(choice_mode='ratio'))),\n",
      "    target_pruning_ratio=target_pruning_ratio,\n",
      "    step_freq=1,\n",
      "    prune_times=1,\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Swap the empty placeholder for the real ratios, rewrite the prune config file\n",
    "# and print the result.\n",
    "config_string = config_string.replace('target_pruning_ratio={}', target_config)\n",
    "write_config(config_string, prune_config_path)\n",
    "! cat $prune_config_path"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. 运行"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Launch training with the completed prune config; `timeout 2` stops the run\n",
    "# after two seconds.\n",
    "! timeout 2 python ./tools/train.py $prune_config_path"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 自动生成剪枝Config"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们提供了一键生成剪枝config的工具get_l1_prune_config.py。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "usage: get_l1_prune_config.py [-h] [--checkpoint CHECKPOINT] [--subnet SUBNET]\n",
      "                              [-o O]\n",
      "                              config\n",
      "\n",
      "Get the config to prune a model.\n",
      "\n",
      "positional arguments:\n",
      "  config                config of the model\n",
      "\n",
      "optional arguments:\n",
      "  -h, --help            show this help message and exit\n",
      "  --checkpoint CHECKPOINT\n",
      "                        checkpoint path of the model\n",
      "  --subnet SUBNET       pruning structure for the model\n",
      "  -o O                  output path to store the pruning config.\n"
     ]
    }
   ],
   "source": [
    "# Print the command-line help of the config generation tool.\n",
    "! python ./tools/pruning/get_l1_prune_config.py -h"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model = dict(\n",
      "    _scope_='mmrazor',\n",
      "    type='ItePruneAlgorithm',\n",
      "    architecture=dict(\n",
      "        type='ImageClassifier',\n",
      "        backbone=dict(\n",
      "            type='ResNet',\n",
      "            depth=34,\n",
      "            num_stages=4,\n",
      "            out_indices=(3, ),\n",
      "            style='pytorch'),\n",
      "        neck=dict(type='GlobalAveragePooling'),\n",
      "        head=dict(\n",
      "            type='LinearClsHead',\n",
      "            num_classes=1000,\n",
      "            in_channels=512,\n",
      "            loss=dict(type='CrossEntropyLoss', loss_weight=1.0),\n",
      "            topk=(1, 5)),\n",
      "        _scope_='mmcls',\n",
      "        init_cfg=dict(\n",
      "            type='Pretrained',\n",
      "            checkpoint=\n",
      "            'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth'\n",
      "        ),\n",
      "        data_preprocessor=dict(\n",
      "            mean=[123.675, 116.28, 103.53],\n",
      "            std=[58.395, 57.12, 57.375],\n",
      "            to_rgb=True)),\n",
      "    target_pruning_ratio=dict({\n",
      "        'backbone.conv1_(0, 64)_64': 1.0,\n",
      "        'backbone.layer1.0.conv1_(0, 64)_64': 1.0,\n",
      "        'backbone.layer1.1.conv1_(0, 64)_64': 1.0,\n",
      "        'backbone.layer1.2.conv1_(0, 64)_64': 1.0,\n",
      "        'backbone.layer2.0.conv1_(0, 128)_128': 1.0,\n",
      "        'backbone.layer2.0.conv2_(0, 128)_128': 1.0,\n",
      "        'backbone.layer2.1.conv1_(0, 128)_128': 1.0,\n",
      "        'backbone.layer2.2.conv1_(0, 128)_128': 1.0,\n",
      "        'backbone.layer2.3.conv1_(0, 128)_128': 1.0,\n",
      "        'backbone.layer3.0.conv1_(0, 256)_256': 1.0,\n",
      "        'backbone.layer3.0.conv2_(0, 256)_256': 1.0,\n",
      "        'backbone.layer3.1.conv1_(0, 256)_256': 1.0,\n",
      "        'backbone.layer3.2.conv1_(0, 256)_256': 1.0,\n",
      "        'backbone.layer3.3.conv1_(0, 256)_256': 1.0,\n",
      "        'backbone.layer3.4.conv1_(0, 256)_256': 1.0,\n",
      "        'backbone.layer3.5.conv1_(0, 256)_256': 1.0,\n",
      "        'backbone.layer4.0.conv1_(0, 512)_512': 1.0,\n",
      "        'backbone.layer4.0.conv2_(0, 512)_512': 1.0,\n",
      "        'backbone.layer4.1.conv1_(0, 512)_512': 1.0,\n",
      "        'backbone.layer4.2.conv1_(0, 512)_512': 1.0\n",
      "    }),\n",
      "    mutator_cfg=dict(\n",
      "        type='ChannelMutator',\n",
      "        channel_unit_cfg=dict(\n",
      "            type='L1MutableChannelUnit',\n",
      "            default_args=dict(choice_mode='ratio')),\n",
      "        parse_cfg=dict(\n",
      "            type='ChannelAnalyzer',\n",
      "            tracer_type='FxTracer',\n",
      "            demo_input=dict(type='DefaultDemoInput', scope='mmcls'))))\n",
      "dataset_type = 'ImageNet'\n",
      "data_preprocessor = None\n",
      "train_pipeline = [\n",
      "    dict(type='LoadImageFromFile', _scope_='mmcls'),\n",
      "    dict(type='RandomResizedCrop', scale=224, _scope_='mmcls'),\n",
      "    dict(type='RandomFlip', prob=0.5, direction='horizontal', _scope_='mmcls'),\n",
      "    dict(type='PackClsInputs', _scope_='mmcls')\n",
      "]\n",
      "test_pipeline = [\n",
      "    dict(type='LoadImageFromFile', _scope_='mmcls'),\n",
      "    dict(type='ResizeEdge', scale=256, edge='short', _scope_='mmcls'),\n",
      "    dict(type='CenterCrop', crop_size=224, _scope_='mmcls'),\n",
      "    dict(type='PackClsInputs', _scope_='mmcls')\n",
      "]\n",
      "train_dataloader = dict(\n",
      "    batch_size=32,\n",
      "    num_workers=5,\n",
      "    dataset=dict(\n",
      "        type='ImageNet',\n",
      "        data_root='data/imagenet',\n",
      "        ann_file='meta/train.txt',\n",
      "        data_prefix='train',\n",
      "        pipeline=[\n",
      "            dict(type='LoadImageFromFile'),\n",
      "            dict(type='RandomResizedCrop', scale=224),\n",
      "            dict(type='RandomFlip', prob=0.5, direction='horizontal'),\n",
      "            dict(type='PackClsInputs')\n",
      "        ],\n",
      "        _scope_='mmcls'),\n",
      "    sampler=dict(type='DefaultSampler', shuffle=True, _scope_='mmcls'),\n",
      "    persistent_workers=True)\n",
      "val_dataloader = dict(\n",
      "    batch_size=32,\n",
      "    num_workers=5,\n",
      "    dataset=dict(\n",
      "        type='ImageNet',\n",
      "        data_root='data/imagenet',\n",
      "        ann_file='meta/val.txt',\n",
      "        data_prefix='val',\n",
      "        pipeline=[\n",
      "            dict(type='LoadImageFromFile'),\n",
      "            dict(type='ResizeEdge', scale=256, edge='short'),\n",
      "            dict(type='CenterCrop', crop_size=224),\n",
      "            dict(type='PackClsInputs')\n",
      "        ],\n",
      "        _scope_='mmcls'),\n",
      "    sampler=dict(type='DefaultSampler', shuffle=False, _scope_='mmcls'),\n",
      "    persistent_workers=True)\n",
      "val_evaluator = dict(type='Accuracy', topk=(1, 5), _scope_='mmcls')\n",
      "test_dataloader = dict(\n",
      "    batch_size=32,\n",
      "    num_workers=5,\n",
      "    dataset=dict(\n",
      "        type='ImageNet',\n",
      "        data_root='data/imagenet',\n",
      "        ann_file='meta/val.txt',\n",
      "        data_prefix='val',\n",
      "        pipeline=[\n",
      "            dict(type='LoadImageFromFile'),\n",
      "            dict(type='ResizeEdge', scale=256, edge='short'),\n",
      "            dict(type='CenterCrop', crop_size=224),\n",
      "            dict(type='PackClsInputs')\n",
      "        ],\n",
      "        _scope_='mmcls'),\n",
      "    sampler=dict(type='DefaultSampler', shuffle=False, _scope_='mmcls'),\n",
      "    persistent_workers=True)\n",
      "test_evaluator = dict(type='Accuracy', topk=(1, 5), _scope_='mmcls')\n",
      "optim_wrapper = dict(\n",
      "    optimizer=dict(\n",
      "        type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001,\n",
      "        _scope_='mmcls'))\n",
      "param_scheduler = dict(\n",
      "    type='MultiStepLR',\n",
      "    by_epoch=True,\n",
      "    milestones=[30, 60, 90],\n",
      "    gamma=0.1,\n",
      "    _scope_='mmcls')\n",
      "train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)\n",
      "val_cfg = dict()\n",
      "test_cfg = dict()\n",
      "auto_scale_lr = dict(base_batch_size=256)\n",
      "default_scope = 'mmcls'\n",
      "default_hooks = dict(\n",
      "    timer=dict(type='IterTimerHook', _scope_='mmcls'),\n",
      "    logger=dict(type='LoggerHook', interval=100, _scope_='mmcls'),\n",
      "    param_scheduler=dict(type='ParamSchedulerHook', _scope_='mmcls'),\n",
      "    checkpoint=dict(type='CheckpointHook', interval=1, _scope_='mmcls'),\n",
      "    sampler_seed=dict(type='DistSamplerSeedHook', _scope_='mmcls'),\n",
      "    visualization=dict(\n",
      "        type='VisualizationHook', enable=False, _scope_='mmcls'))\n",
      "env_cfg = dict(\n",
      "    cudnn_benchmark=False,\n",
      "    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),\n",
      "    dist_cfg=dict(backend='nccl'))\n",
      "vis_backends = [dict(type='LocalVisBackend', _scope_='mmcls')]\n",
      "visualizer = dict(\n",
      "    type='ClsVisualizer',\n",
      "    vis_backends=[dict(type='LocalVisBackend')],\n",
      "    _scope_='mmcls')\n",
      "log_level = 'INFO'\n",
      "load_from = None\n",
      "resume = False\n"
     ]
    }
   ],
   "source": [
    "# Generate the pruning config from the pretrain config in one step, then print it.\n",
    "# Reuses pretrain_config_path instead of rebuilding the path, and the portable\n",
    "# `> /dev/null 2>&1` redirection (`&>` is bash-only and backgrounds the command\n",
    "# under a plain sh).\n",
    "! python ./tools/pruning/get_l1_prune_config.py $pretrain_config_path --checkpoint $checkpoint_path -o $prune_config_path > /dev/null 2>&1\n",
    "! cat $prune_config_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clean up the temporary files.\n",
    "! rm -r $work_dir"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.13 ('lab2max')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "e31a827d0913016ad78e01c7b97f787f4b9e53102dd62d238e8548bcd97ff875"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
